[tests] Review tests for PR #600 #2

Closed
danielhanchen wants to merge 8 commits into main from pr-600-tests

@danielhanchen

Owner

Automated test files from review process

Gemma-4 E2B/E4B GRPO training hits a CUDA device-side assert at step 2
when dtype=torch.float16 because the text-decoder MLP saturates in fp16:
gate_proj and up_proj outputs reach fp16_max under the fp16 autocast
context, the gate*up product overflows on downcast, and the subsequent
down_proj fp16 matmul accumulator tips to +inf. Generation with an inf
residual stream produces NaN logits and the categorical sampler aborts.
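
The failure chain above can be reproduced on scalars with the standard
library alone. This is a minimal sketch: the magnitudes (300, 400, 1.5)
are hypothetical, `to_fp16` is a helper introduced here to emulate an fp16
cast, and the single-element RMS rescale only stands in for a later
normalization step.

```python
import math
import struct

def to_fp16(value):
    """Round a Python float to the nearest fp16 value; overflow becomes +-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        # struct refuses to pack out-of-range halves; a hardware cast gives inf.
        return math.copysign(math.inf, value)

# Hypothetical gate/up activations, each comfortably inside fp16 range.
gate = to_fp16(300.0)
up = to_fp16(400.0)

# The product (120000) exceeds fp16_max (65504) and overflows on downcast.
product = to_fp16(gate * up)

# Adding the overflowed value to the residual stream keeps it at inf.
residual = to_fp16(1.5 + product)

# An RMS-style rescale of an inf value computes inf / inf, which is NaN.
rms = math.sqrt(residual * residual)
normed = residual / rms
```

Both inputs are finite; only the product and everything downstream of it
are not.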

bf16 training is unaffected because the same intermediate magnitudes fit
in bf16's range. Mirror the Gemma-3 UNSLOTH_FORCE_FLOAT32 recipe with
four Gemma-4 specific patches, all gated on the env flag:

  * Gemma4RMSNorm: fp32 norm and scale, clamp to 65280 before fp16 cast
    (65280 is one bf16 ulp below fp16_max, avoiding the bf16 round-up
    that otherwise produces +-inf on later fp16 conversions).
  * Gemma4TextScaledWordEmbedding: embed lookup and scale in fp32.
  * Gemma4TextMLP: execute gate_proj and up_proj with autocast disabled
    so they see bf16 weights (the actual model dtype under
    UNSLOTH_FORCE_FLOAT32), run activation plus multiply in fp32, clamp
    to the safe fp16 bound, then down_proj in fp16. A final nan_to_num
    rescues the rare case where down_proj's fp16 accumulator overflows
    for wide intermediate dims.
  * Gemma4TextAttention: same autocast-disabled pattern for Q/K/V
    projections, fp32 q_norm/k_norm/v_norm, fp32 RoPE, fp32 SDPA, clamp
    and fp16 cast before o_proj, plus nan_to_num safety net.

Repro recipe (see PR description for full table):

  python scripts/gemma4_grpo_sudoku_repro.py \
    --model unsloth/gemma-4-E2B-it --dtype float16 \
    --max-steps 8 --grad-accum 4 --num-generations 2

Without the patches this crashes at step 2 with CUDA device-side assert
and overflow flagged in layers.0.mlp.down_proj. With the patches
unsloth/gemma-4-E2B-it and unsloth/gemma-4-E4B-it both complete 8 GRPO
steps with finite loss and grad_norm values close to the bf16 baseline,
under both gradient_checkpointing="unsloth" and
gradient_checkpointing=False. bf16 behavior is unaffected because the
patches are gated on UNSLOTH_FORCE_FLOAT32=1 which only fires for float16
requests.

Bisection showed the NaN chain originates exclusively in the text-decoder
MLP gate*up -> down_proj path:

  * Gemma4RMSNorm already computes internally in fp32 and returns
    type_as(hidden_states), which is finite for the trained weight ranges.
  * Gemma4TextScaledWordEmbedding scales by sqrt(hidden_size) which is
    ~45 to 60 for E2B/E4B and fits in fp16.
  * Gemma4TextAttention projections did not overflow in any run.

So the earlier RMSNorm, Embedding, and Attention patches were redundant
once the MLP is stabilized. Removing them lets the text pipeline keep its
original dtype contract (no autocast gymnastics, no KV cache dtype
mismatch), which in turn keeps the gate/up/down matmuls on fp16 tensor
cores.

The MLP patch itself is reduced to the three operations that matter:

  1. Compute act_fn(gate) * up in fp32 so the product cannot overflow.
  2. Clamp to 65280 (one bf16 ulp below fp16_max) before down_proj so
     the fp16 cast cannot produce +-inf.
  3. nan_to_num on the output as defense-in-depth for the rare fp16
     accumulator overflow in down_proj for wide intermediate dims.
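
Steps 1 and 2 can be walked through on scalars. This is a sketch under
stated assumptions, not the patch itself: a Python float stands in for
fp32, `to_fp16` and `stabilized_product` are hypothetical helpers, the
tanh-approximate GELU is only a stand-in for the model's `act_fn`, and the
inputs (300, 400) are made up.

```python
import math
import struct

FP16_SAFE_MAX = 65280.0  # clamp bound named in the commit message

def to_fp16(value):
    """Round a Python float to the nearest fp16 value; overflow becomes +-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        return math.copysign(math.inf, value)

def gelu_tanh(x):
    """Tanh-approximate GELU, used here as a stand-in activation."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

def stabilized_product(gate, up):
    # Step 1: form the product in wide precision so it cannot overflow.
    product = gelu_tanh(gate) * up
    # Step 2: clamp to the safe bound, then cast down to fp16.
    clamped = max(-FP16_SAFE_MAX, min(FP16_SAFE_MAX, product))
    return to_fp16(clamped)

# Naive path: every intermediate is rounded to fp16, so the product overflows.
naive = to_fp16(to_fp16(gelu_tanh(300.0)) * to_fp16(400.0))

# Stabilized path: the same inputs saturate at the clamp bound instead.
safe = stabilized_product(300.0, 400.0)
```

Small activations pass through unchanged apart from normal fp16 rounding;
only values that would have overflowed are pinned to the bound.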

Verified on B200 (fp16 autocast under GRPO) - all runs complete 8 steps
with no NaN and no bad params:

  | model                       | gc       | step1 grad | step8 grad | step8 kl |
  | unsloth/gemma-4-E2B-it      | unsloth  | 0.1110     | 0.4077     | 6.03e-05 |
  | unsloth/gemma-4-E2B-it      | off      | 0.1110     | 0.2655     | 1.18e-04 |
  | unsloth/gemma-4-E4B-it      | unsloth  | 0.0987     | 0.0000     | 1.47e-05 |

Tesla T4 compatibility: T4 has no bf16 tensor cores so the loader's
FORCE_FLOAT32 path lands on fp16 weights. The patch's fp32 gate*up is an
elementwise op (runs on CUDA cores at 8.1 TFLOPS); gate_proj, up_proj,
and down_proj matmuls stay on fp16 tensor cores (65 TFLOPS). Net perf
overhead is minimal and the overflow prevention cost is paid only once
per MLP forward.

UNSLOTH_ENABLE_FLEX_ATTENTION on Gemma-4: flex is enabled by default but
only used when sdpa is unavailable. Gemma-4 supports sdpa so flex is not
selected for the text decoder in practice. Setting the env var to 0
produces identical E4B trajectories and near-identical E2B (diverges only
at step 8, still finite, no crash).

The int8 branch (torch.ops.aten._weight_int8pack_mm) worked end-to-end
but failed the RL-parity goal: 100-step GRPO on unsloth/gemma-4-E2B-it
with temperature=0.05, min_p=0.5, seed=3407 gave total |KL| of 7.47e+06
for int8 vs 52.16 for fp16+clamp vs 9387 for bf16, and int8 step 7
already had grad_norm=NaN.

Cause: GRPO's log-pi-new - log-pi-old ratio amplifies the ~7%
per-matmul weight-quantization noise, because the rollout path runs
under torch.inference_mode (dense fp16) while the training forward
used int8. Making them symmetric would require a separate int8
reference model, which is out of scope for a surgical NaN fix.
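
To get a feel for how a log-probability mismatch between the rollout and
training forwards shows up in the GRPO terms, here is a small sketch. It
assumes the importance ratio exp(log-pi-new - log-pi-old) and the
exp(d) - d - 1 per-token KL estimator; the function names and the gap
values are hypothetical and are not measurements from these runs.

```python
import math

def importance_ratio(logp_new, logp_old):
    """Policy ratio pi_new / pi_old from per-token log-probabilities."""
    return math.exp(logp_new - logp_old)

def k3_kl(logp_new, logp_ref):
    """Per-token KL estimate exp(d) - d - 1 with d = logp_ref - logp_new."""
    delta = logp_ref - logp_new
    return math.exp(delta) - delta - 1.0

# Hypothetical log-prob gaps caused purely by numerical mismatch.
gaps = [0.01, 0.1, 1.0, 5.0]
kl_per_gap = [k3_kl(-2.0 - gap, -2.0) for gap in gaps]
ratio_per_gap = [importance_ratio(-2.0 + gap, -2.0) for gap in gaps]
```

The estimate is roughly d^2 / 2 for small gaps and grows exponentially for
large ones, so a handful of badly mismatched tokens can dominate a total.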

This commit removes all int8 code (Int8LinearFn autograd wrapper, row
quantizer, PEFT LoRA-aware int8 forward, UNSLOTH_GEMMA4_MLP_INT8 env
switch) and keeps only the fp16 trick:

  - fp32 act_fn(gate) * up so the product cannot overflow
  - clamp to 65280 before down_proj
  - nan_to_num on down_proj output for the rare accumulator tail

The patched forward is now 7 effective lines. Dtype contract is
identical to upstream (input dtype -> input dtype), so no attention /
RMSNorm / embedding companion patches are needed and the KV cache
stays aligned.

100-step verification vs bf16 (temp=0.05, min_p=0.5, seed=3407):

  median |KL|   bf16 4.02e-05  fp16+clamp 7.99e-06 (0.20x)
  p95 |KL|      bf16 0.411     fp16+clamp 0.620    (1.51x)
  max |KL|      bf16 8529      fp16+clamp 35.77    (0.004x)
  total |KL|    bf16 9388      fp16+clamp 52.16    (0.006x)
  mean reward   bf16 1.2111    fp16+clamp 1.1359   (|d|=0.075)
  reward-equal  75/100 steps
  time          +0.3%

Note: bf16's two huge outlier steps (step 24 |KL|=8529 grad=26461;
step 42 |KL|=800 grad=1434) are what push its total KL above
fp16+clamp's. For the calm 63-66% of steps both trajectories track
each other to 1e-5 precision.

65280 is the largest value exactly representable in both fp16 and bf16
(one bf16 ULP below 65536, 224 below fp16_max=65504). The previous
wording claimed it was one bf16 ulp below fp16_max, but fp16_max=65504
is not representable in bf16 at all -- it rounds up to 65536.
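
The representability claim can be checked with the standard library. This
is a sketch: `to_fp16` and `to_bf16` are helpers introduced here, and the
bf16 rounding is emulated on the float32 bit pattern (round to nearest,
ties to even), which is valid for the finite positive values used below.

```python
import struct

def to_fp16(value):
    """Round-trip through IEEE half precision (in-range values only)."""
    return struct.unpack("<e", struct.pack("<e", value))[0]

def to_bf16(value):
    """Emulate a float32 -> bfloat16 cast by rounding off the low 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

FP16_MAX = 65504.0

# fp16_max itself is not a bf16 value: it rounds up past the fp16 range.
fp16_max_in_bf16 = to_bf16(FP16_MAX)

# Walk down through fp16 values (spacing 32 in this binade) until one
# survives the bf16 round trip unchanged.
candidate = FP16_MAX
while to_bf16(candidate) != candidate:
    candidate -= 32.0
```

The loop stops at the first fp16 value that is also a bf16 value, which is
the clamp bound used by the patch.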

Clamp value is unchanged; comment only.
…nan rescue, inline import

- Add `x.dtype != torch.float16` guard so bf16/fp32 activations pass
  through the upstream forward unchanged. Under UNSLOTH_FORCE_FLOAT32
  the weights are fp16 so this is normally unreachable, but the guard
  protects against pipeline configurations where autocast is disabled
  (rl_replacements nullcontext path) and prevents unintended clamping
  if activations reach the MLP as bf16.

- Change nan_to_num replacements from +-65280 to 0 so overflow
  positions contribute nothing and leave the identity residual intact.
  Replacing with the fp16 ceiling would otherwise dominate the O(1)
  post-RMSNorm hidden state. Backward gradient through nan_to_num is
  `grad * isfinite(input)` in both cases, so this is loss-free.
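
A short PyTorch check of the forward and backward behaviour described in
this bullet. It is a sketch: `fill_and_grad` is a hypothetical helper, the
tensor is fp32 on CPU for simplicity, and the NaN replacement is set to 0
in both variants as an assumption.

```python
import torch

def fill_and_grad(fill):
    """Apply nan_to_num with +-fill for infinities; return output and input grad."""
    x = torch.tensor(
        [1.5, float("inf"), float("-inf"), float("nan")], requires_grad=True
    )
    out = torch.nan_to_num(x, nan=0.0, posinf=fill, neginf=-fill)
    out.sum().backward()
    return out.detach(), x.grad

# Zero fill: overflow positions contribute nothing to the residual stream.
out_zero, grad_zero = fill_and_grad(0.0)

# Ceiling fill: overflow positions would inject +-65280 instead.
out_ceiling, grad_ceiling = fill_and_grad(65280.0)
```

Only the forward values differ between the two variants; the gradient
reaching the input is zero at every non-finite position either way.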

- Drop the one-off `_gemma4_modeling` helper and inline the import to
  match the pattern of every other patch in this file.

- Drop the `x: torch.Tensor` annotation to match the rest of the
  temporary patches (`def forward(self, x):`).

- Switch the stabilization guard from `x.dtype != torch.float16` to
  `gate.dtype != torch.float16`. The matmul output dtype is what
  determines whether overflow can occur, so this also catches
  mixed-precision cases (bf16 activations through fp16-cast weights via
  autocast or do_forced_float32) that the x.dtype check missed. For
  the standard UNSLOTH_FORCE_FLOAT32 path behavior is unchanged.

- Inline `self.up_proj(x).float()` so the fp16 up tensor is transient,
  and reuse the already-computed gate in the bypass path.

- Cast product back with `gate.dtype` instead of `x.dtype` to avoid a
  bf16-input/fp16-weight mismatch at down_proj if stabilization is
  active with non-fp16 activations.
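
Putting the bullets above together, the patched forward might look roughly
like the sketch below. This is not the code from the PR: it is written as
a free function over callables so it runs standalone, the name
`stabilized_mlp_forward` and the constant name are hypothetical, and the
exact placement of the fp32 casts is an assumption.

```python
import torch

FP16_SAFE_MAX = 65280.0  # clamp bound from the commit messages

def stabilized_mlp_forward(gate_proj, up_proj, down_proj, act_fn, x):
    gate = gate_proj(x)
    if gate.dtype != torch.float16:
        # Bypass: upstream computation, reusing the gate already computed.
        return down_proj(act_fn(gate) * up_proj(x))
    # fp32 product; the fp16 up tensor is transient.
    hidden = act_fn(gate.float()) * up_proj(x).float()
    # Clamp, then cast back with gate.dtype rather than x.dtype.
    hidden = torch.clamp(hidden, -FP16_SAFE_MAX, FP16_SAFE_MAX).to(gate.dtype)
    out = down_proj(hidden)
    # Zero out any residual overflow from down_proj.
    return torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)

# Usage with stand-in projections (plain elementwise scalings, not real layers).
gelu = lambda t: torch.nn.functional.gelu(t, approximate="tanh")
x16 = torch.tensor([2.0, 300.0], dtype=torch.float16)
out_ok = stabilized_mlp_forward(lambda t: t, lambda t: t * 200, lambda t: t * 0.5, gelu, x16)
out_overflow = stabilized_mlp_forward(lambda t: t, lambda t: t * 200, lambda t: t * 4, gelu, x16)
out_bypass = stabilized_mlp_forward(lambda t: t, lambda t: t * 200, lambda t: t * 0.5, gelu, x16.float())
```

With real `nn.Linear` projections the same function applies unchanged; the
lambdas are only there to keep the example small and CPU-friendly.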
@danielhanchen

Owner Author

Fixes pushed to unslothai#600.
